{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qFdPvlXBOdUN"
      },
      "source": [
        "# Image classification with Model Garden"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/vision/image_classification\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/vision/image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/vision/image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/vision/image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ta_nFXaVAqLD"
      },
      "source": [
        "This tutorial fine-tunes a Residual Network (ResNet) from the TensorFlow [Model Garden](https://github.com/tensorflow/models) package (`tensorflow-models`) to classify images in the [CIFAR](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.\n",
        "\n",
        "Model Garden contains a collection of state-of-the-art vision models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users to take full advantage of TensorFlow for their research and product development.\n",
        "\n",
        "This tutorial uses a [ResNet](https://arxiv.org/pdf/1512.03385.pdf) model, a state-of-the-art image classifier. This tutorial uses the ResNet-18 model, a convolutional neural network with 18 layers.\n",
        "\n",
        "This tutorial demonstrates how to:\n",
        "1. Use models from the TensorFlow Models package.\n",
        "2. Fine-tune a pre-built ResNet for image classification.\n",
        "3. Export the tuned ResNet model."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G2FlaQcEPOER"
      },
      "source": [
        "## Setup\n",
        "\n",
        "Install and import the necessary modules."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XvWfdCrvrV5W"
      },
      "outputs": [],
      "source": [
        "!pip install -U -q \"tf-models-official\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CKYMTPjOE400"
      },
      "source": [
        "Import TensorFlow, TensorFlow Datasets, and a few helper libraries."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Wlon1uoIowmZ"
      },
      "outputs": [],
      "source": [
        "import pprint\n",
        "import tempfile\n",
        "\n",
        "from IPython import display\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AVTs0jDd1b24"
      },
      "source": [
        "The `tensorflow_models` package contains the ResNet vision model, and the `official.vision.serving` model contains the function to save and export the tuned model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NHT1iiIiBzlC"
      },
      "outputs": [],
      "source": [
        "import tensorflow_models as tfm\n",
        "\n",
        "# These are not in the tfm public API for v2.9. They will be available in v2.10\n",
        "from official.vision.serving import export_saved_model_lib\n",
        "import official.core.train_lib"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aKv3wdqkQ8FU"
      },
      "source": [
        "## Configure the ResNet-18 model for the Cifar-10 dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5iN8mHEJjKYE"
      },
      "source": [
        "The CIFAR10 dataset contains 60,000 color images in mutually exclusive 10 classes, with 6,000 images in each class.\n",
        "\n",
        "In Model Garden, the collections of parameters that define a model are called *configs*. Model Garden can create a config based on a known set of parameters via a [factory](https://en.wikipedia.org/wiki/Factory_method_pattern).\n",
        "\n",
        "Use the `resnet_imagenet` factory configuration, as defined by `tfm.vision.configs.image_classification.image_classification_imagenet`. The configuration is set up to train ResNet to converge on [ImageNet](https://www.image-net.org/)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1M77f88Dj2Td"
      },
      "outputs": [],
      "source": [
        "exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')\n",
        "tfds_name = 'cifar10'\n",
        "ds,ds_info = tfds.load(\n",
            "tfds_name,\n",
            "with_info=True)\n",
        "ds_info"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U6PVwXA-j3E7"
      },
      "source": [
        "Adjust the model and dataset configurations so that it works with Cifar-10 (`cifar10`)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YWI7faVStQaV"
      },
      "outputs": [],
      "source": [
        "# Configure model\n",
        "exp_config.task.model.num_classes = 10\n",
        "exp_config.task.model.input_size = list(ds_info.features[\"image\"].shape)\n",
        "exp_config.task.model.backbone.resnet.model_id = 18\n",
        "\n",
        "# Configure training and testing data\n",
        "batch_size = 128\n",
        "\n",
        "exp_config.task.train_data.input_path = ''\n",
        "exp_config.task.train_data.tfds_name = tfds_name\n",
        "exp_config.task.train_data.tfds_split = 'train'\n",
        "exp_config.task.train_data.global_batch_size = batch_size\n",
        "\n",
        "exp_config.task.validation_data.input_path = ''\n",
        "exp_config.task.validation_data.tfds_name = tfds_name\n",
        "exp_config.task.validation_data.tfds_split = 'test'\n",
        "exp_config.task.validation_data.global_batch_size = batch_size\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DE3ggKzzTD56"
      },
      "source": [
        "Adjust the trainer configuration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "inE_-4UGkLud"
      },
      "outputs": [],
      "source": [
        "logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
        "\n",
        "if 'GPU' in ''.join(logical_device_names):\n",
        "  print('This may be broken in Colab.')\n",
        "  device = 'GPU'\n",
        "elif 'TPU' in ''.join(logical_device_names):\n",
        "  print('This may be broken in Colab.')\n",
        "  device = 'TPU'\n",
        "else:\n",
        "  print('Running on CPU is slow, so only train for a few steps.')\n",
        "  device = 'CPU'\n",
        "\n",
        "if device=='CPU':\n",
        "  train_steps = 20\n",
        "  exp_config.trainer.steps_per_loop = 5\n",
        "else:\n",
        "  train_steps=5000\n",
        "  exp_config.trainer.steps_per_loop = 100\n",
        "\n",
        "exp_config.trainer.summary_interval = 100\n",
        "exp_config.trainer.checkpoint_interval = train_steps\n",
        "exp_config.trainer.validation_interval = 1000\n",
        "exp_config.trainer.validation_steps =  ds_info.splits['test'].num_examples // batch_size\n",
        "exp_config.trainer.train_steps = train_steps\n",
        "exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'\n",
        "exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps\n",
        "exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1\n",
        "exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 100"
      ]
    },
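    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the learning-rate settings above concrete, the next cell sketches in plain Python the schedule that a cosine decay with a linear warmup roughly produces. It is an illustration under stated assumptions, not Model Garden's implementation: `sketch_learning_rate` is a hypothetical helper, the warmup is assumed to start at 0 and ramp linearly to the value of the cosine curve at the end of the warmup, and the decay is assumed to end at 0. The default arguments mirror the accelerator settings above (5,000 train steps, an initial rate of 0.1, and 100 warmup steps)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "# Hypothetical helper (not part of Model Garden): cosine decay preceded by\n",
        "# a linear warmup, under the assumptions stated in the text above.\n",
        "def sketch_learning_rate(step, initial_lr=0.1, decay_steps=5000, warmup_steps=100):\n",
        "  def cosine(s):\n",
        "    # Standard cosine decay from initial_lr down to 0 over decay_steps.\n",
        "    s = min(s, decay_steps)\n",
        "    return 0.5 * initial_lr * (1.0 + math.cos(math.pi * s / decay_steps))\n",
        "\n",
        "  if step < warmup_steps:\n",
        "    # Assumed warmup: ramp linearly from 0 to the cosine value at warmup_steps.\n",
        "    return cosine(warmup_steps) * step / warmup_steps\n",
        "  return cosine(step)\n",
        "\n",
        "for sketch_step in [0, 50, 100, 2500, 5000]:\n",
        "  print(f'step {sketch_step:5d}: learning rate = {sketch_learning_rate(sketch_step):.4f}')"
      ]
    },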
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5mTcDnBiTOYD"
      },
      "source": [
        "Print the modified configuration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tuVfxSBCTK-y"
      },
      "outputs": [],
      "source": [
        "pprint.pprint(exp_config.as_dict())\n",
        "\n",
        "display.Javascript(\"google.colab.output.setIframeHeight('300px');\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w7_X0UHaRF2m"
      },
      "source": [
        "Set up the distribution strategy."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ykL14FIbTaSt"
      },
      "outputs": [],
      "source": [
        "logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
        "\n",
        "if exp_config.runtime.mixed_precision_dtype == tf.float16:\n",
        "    tf.keras.mixed_precision.set_global_policy('mixed_float16')\n",
        "\n",
        "if 'GPU' in ''.join(logical_device_names):\n",
        "  distribution_strategy = tf.distribute.MirroredStrategy()\n",
        "elif 'TPU' in ''.join(logical_device_names):\n",
        "  tf.tpu.experimental.initialize_tpu_system()\n",
        "  tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')\n",
        "  distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
        "else:\n",
        "  print('Warning: this will be really slow.')\n",
        "  distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W4k5YH5pTjaK"
      },
      "source": [
        "Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`.\n",
        "\n",
        "The `Task` object has all the methods necessary for building the dataset, building the model, and running training \u0026 evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6MgYSH0PtUaW"
      },
      "outputs": [],
      "source": [
        "with distribution_strategy.scope():\n",
        "  model_dir = tempfile.mkdtemp()\n",
        "  task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)\n",
        "\n",
        "#  tf.keras.utils.plot_model(task.build_model(), show_shapes=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IFXEZYdzBKoX"
      },
      "outputs": [],
      "source": [
        "for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
        "  print()\n",
        "  print(f'images.shape: {str(images.shape):16}  images.dtype: {images.dtype!r}')\n",
        "  print(f'labels.shape: {str(labels.shape):16}  labels.dtype: {labels.dtype!r}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yrwxnGDaRU0U"
      },
      "source": [
        "## Visualize the training data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "683c255c6c52"
      },
      "source": [
        "The dataloader applies a z-score normalization using \n",
        "`preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`, so the images returned by the dataset can't be directly displayed by standard tools. The visualization code needs to rescale the data into the [0,1] range."
      ]
    },
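    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of the normalization described above, the next cell applies a per-channel z-score to a random `uint8` image with NumPy and then min-max rescales the result into [0,1] for display. This is a sketch, not the dataloader's code: the helper names `zscore_normalize` and `rescale_for_display` are hypothetical, and the `SKETCH_MEAN_RGB` and `SKETCH_STDDEV_RGB` values are assumed to be the usual ImageNet statistics on the 0-255 scale, which may differ from the constants the dataloader uses."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Assumed per-channel statistics (ImageNet mean and standard deviation on the\n",
        "# 0-255 scale). The values used by the dataloader may differ.\n",
        "SKETCH_MEAN_RGB = np.array([0.485, 0.456, 0.406]) * 255\n",
        "SKETCH_STDDEV_RGB = np.array([0.229, 0.224, 0.225]) * 255\n",
        "\n",
        "def zscore_normalize(image):\n",
        "  # Subtract the channel mean and divide by the channel standard deviation.\n",
        "  return (image.astype(np.float32) - SKETCH_MEAN_RGB) / SKETCH_STDDEV_RGB\n",
        "\n",
        "def rescale_for_display(normalized):\n",
        "  # Min-max rescale into [0, 1] so that plotting tools can show the image.\n",
        "  low, high = normalized.min(), normalized.max()\n",
        "  return (normalized - low) / (high - low)\n",
        "\n",
        "sketch_rng = np.random.default_rng(seed=0)\n",
        "fake_image = sketch_rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)\n",
        "normalized_image = zscore_normalize(fake_image)\n",
        "displayable_image = rescale_for_display(normalized_image)\n",
        "\n",
        "print('normalized range: ', normalized_image.min(), normalized_image.max())\n",
        "print('displayable range:', displayable_image.min(), displayable_image.max())"
      ]
    },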
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PdmOz2EC0Nx2"
      },
      "outputs": [],
      "source": [
        "plt.hist(images.numpy().flatten());"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7a8582ebde7b"
      },
      "source": [
        "Use `ds_info` (which is an instance of `tfds.core.DatasetInfo`) to lookup the text descriptions of each class ID."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Wq4Wq_CuDG3Q"
      },
      "outputs": [],
      "source": [
        "label_info = ds_info.features['label']\n",
        "label_info.int2str(1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8c652a6fdbcf"
      },
      "source": [
        "Visualize a batch of the data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZKfTxytf1l0d"
      },
      "outputs": [],
      "source": [
        "def show_batch(images, labels, predictions=None):\n",
        "  plt.figure(figsize=(10, 10))\n",
        "  min = images.numpy().min()\n",
        "  max = images.numpy().max()\n",
        "  delta = max - min\n",
        "\n",
        "  for i in range(12):\n",
        "    plt.subplot(6, 6, i + 1)\n",
        "    plt.imshow((images[i]-min) / delta)\n",
        "    if predictions is None:\n",
        "      plt.title(label_info.int2str(labels[i]))\n",
        "    else:\n",
        "      if labels[i] == predictions[i]:\n",
        "        color = 'g'\n",
        "      else:\n",
        "        color = 'r'\n",
        "      plt.title(label_info.int2str(predictions[i]), color=color)\n",
        "    plt.axis(\"off\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xkA5h_RBtYYU"
      },
      "outputs": [],
      "source": [
        "plt.figure(figsize=(10, 10))\n",
        "for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
        "  show_batch(images, labels)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v_A9VnL2RbXP"
      },
      "source": [
        "## Visualize the testing data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AXovuumW_I2z"
      },
      "source": [
        "Visualize a batch of images from the validation dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ma-_Eb-nte9A"
      },
      "outputs": [],
      "source": [
        "plt.figure(figsize=(10, 10));\n",
        "for images, labels in task.build_inputs(exp_config.task.validation_data).take(1):\n",
        "  show_batch(images, labels)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ihKJt2FHRi2N"
      },
      "source": [
        "## Train and evaluate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0AFMNvYxtjXx"
      },
      "outputs": [],
      "source": [
        "model, eval_logs = tfm.core.train_lib.run_experiment(\n",
        "    distribution_strategy=distribution_strategy,\n",
        "    task=task,\n",
        "    mode='train_and_eval',\n",
        "    params=exp_config,\n",
        "    model_dir=model_dir,\n",
        "    run_post_eval=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gCcHMQYhozmA"
      },
      "outputs": [],
      "source": [
        "#  tf.keras.utils.plot_model(model, show_shapes=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L7nVfxlBA8Gb"
      },
      "source": [
        "Print the `accuracy`, `top_5_accuracy`, and `validation_loss` evaluation metrics."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0124f938a1b9"
      },
      "outputs": [],
      "source": [
        "for key, value in eval_logs.items():\n",
        "    if isinstance(value, tf.Tensor):\n",
        "      value = value.numpy()\n",
        "    print(f'{key:20}: {value:.3f}')"
      ]
    },
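    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For intuition about how `accuracy` and `top_5_accuracy` differ, the next cell computes a top-k accuracy with NumPy on a few made-up scores. It is only a sketch: `top_k_accuracy` is a hypothetical helper, not the metric implementation used during evaluation, and the scores and labels are invented for the example. A prediction counts as correct when the true label is among the `k` highest-scoring classes, so top-1 accuracy is the ordinary accuracy."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical helper: fraction of rows whose true label is among the\n",
        "# k highest scores.\n",
        "def top_k_accuracy(scores, true_labels, k):\n",
        "  # Indices of the k largest scores in each row (their order is irrelevant).\n",
        "  top_k = np.argsort(scores, axis=-1)[:, -k:]\n",
        "  hits = (top_k == true_labels[:, None]).any(axis=-1)\n",
        "  return hits.mean()\n",
        "\n",
        "# Made-up scores for 3 examples and 4 classes.\n",
        "example_scores = np.array([\n",
        "    [0.1, 0.6, 0.2, 0.1],\n",
        "    [0.5, 0.1, 0.3, 0.1],\n",
        "    [0.3, 0.15, 0.05, 0.5],\n",
        "])\n",
        "example_labels = np.array([1, 2, 1])\n",
        "\n",
        "print('top-1 accuracy:', top_k_accuracy(example_scores, example_labels, k=1))\n",
        "print('top-2 accuracy:', top_k_accuracy(example_scores, example_labels, k=2))"
      ]
    },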
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TDys5bZ1zsml"
      },
      "source": [
        "Run a batch of the processed training data through the model, and view the results"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GhI7zR-Uz1JT"
      },
      "outputs": [],
      "source": [
        "for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
        "  predictions = model.predict(images)\n",
        "  predictions = tf.argmax(predictions, axis=-1)\n",
        "\n",
        "show_batch(images, labels, tf.cast(predictions, tf.int32))\n",
        "\n",
        "if device=='CPU':\n",
        "  plt.suptitle('The model was only trained for a few steps, it is not expected to do well.')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fkE9locGTBgt"
      },
      "source": [
        "## Export a SavedModel"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9669d08c91af"
      },
      "source": [
        "The `keras.Model` object returned by `train_lib.run_experiment` expects the data to be normalized by the dataset loader using the same mean and variance statiscics in `preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`. This export function handles those details, so you can pass `tf.uint8` images and get the correct results.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AQCFa7BvtmDg"
      },
      "outputs": [],
      "source": [
        "# Saving and exporting the trained model\n",
        "export_saved_model_lib.export_inference_graph(\n",
        "    input_type='image_tensor',\n",
        "    batch_size=1,\n",
        "    input_image_size=[32, 32],\n",
        "    params=exp_config,\n",
        "    checkpoint_path=tf.train.latest_checkpoint(model_dir),\n",
        "    export_dir='./export/')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vVr6DxNqTyLZ"
      },
      "source": [
        "Test the exported model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gP7nOvrftsB0"
      },
      "outputs": [],
      "source": [
        "# Importing SavedModel\n",
        "imported = tf.saved_model.load('./export/')\n",
        "model_fn = imported.signatures['serving_default']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GiOp2WVIUNUZ"
      },
      "source": [
        "Visualize the predictions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BTRMrZQAN4mk"
      },
      "outputs": [],
      "source": [
        "plt.figure(figsize=(10, 10))\n",
        "for data in tfds.load('cifar10', split='test').batch(12).take(1):\n",
        "  predictions = []\n",
        "  for image in data['image']:\n",
        "    index = tf.argmax(model_fn(image[tf.newaxis, ...])['logits'], axis=1)[0]\n",
        "    predictions.append(index)\n",
        "  show_batch(data['image'], data['label'], predictions)\n",
        "\n",
        "  if device=='CPU':\n",
        "    plt.suptitle('The model was only trained for a few steps, it is not expected to do better than random.')"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "classification_with_model_garden.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
